# ------------------------------------------------------------------------------
# Create Tables 1 and 2 (descriptive statistics)
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 07_descriptive_tbl.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(dplyr)
# library(gt) # gt() make table
library(data.table)
library(testit) # assert
library(xtable) # make table w latex formatting
library(glue) # glue
library(scales) # comma_format()

# Project helpers from lib/util.R, loaded as a module object; the joins further
# down call `u$safe_left_join()` from it.
u <- modules::use(here("lib", "util.R"))
# Folder that receives the generated .tex tables at the end of the script.
temp <- here("code", "02_prep_and_summarize_cohort", "temp")

# Helper Functions -------------------------------------------------------------
# Average a single column of `dt` and round the result.
#   varname: column name (string), looked up with `[[`.
#   dt:      table holding that column.
#   dig:     decimal places kept by round(); defaults to 3.
# Missing values are not removed, so any NA in the column propagates.
get_mean <- function(varname, dt, dig = 3) {
  column_avg <- mean(dt[[varname]])
  round(column_avg, digits = dig)
}
# Mean and standard error of one column, formatted for the summary tables.
#
# Args:
#   varname: name (string) of the column in `dt`.
#   dt:      data frame / tibble / data.table holding the column.
#   dig:     decimal places for the mean.
#   dig_se:  decimal places for the standard error.
#   na.rm:   drop missing values first? Defaults to FALSE, which keeps the
#            historical behaviour (NA in the column gives NA results and the
#            SE denominator is the row count). With TRUE, only observed values
#            are counted in the denominator.
#
# Returns: character vector of length 2, c(mean, "(se)"). A standard error
#   below 0.001 is reported as "(<0.001)".
get_mean_se <- function(varname, dt, dig = 3, dig_se = 3, na.rm = FALSE) {
  x <- dt[[varname]]
  n <- if (na.rm) sum(!is.na(x)) else nrow(dt)

  mean_chr <- as.character(round(mean(x, na.rm = na.rm), digits = dig))

  se_raw <- sd(x, na.rm = na.rm) / sqrt(n)
  # isTRUE() keeps a missing SE from tripping the scalar `if`; it is then
  # rendered as "(NA)".
  se_chr <- if (isTRUE(se_raw < 0.001)) {
    "<0.001"
  } else {
    as.character(round(se_raw, digits = dig_se))
  }

  c(mean_chr, paste0("(", se_chr, ")"))
}
# Median with interquartile range of one column, as a single display string.
#
# Args:
#   varname: name (string) of a numeric column in `dt`.
#   dt:      data frame / tibble / data.table holding the column.
#
# Returns: a length-one character of the form "med [q25,q75]", with all three
#   statistics rounded to whole numbers. Missing values are ignored. The value
#   is returned visibly as a plain character string.
get_med <- function(varname, dt) {
  x <- dt[[varname]]
  med <- round(median(x, na.rm = TRUE), digits = 0)
  # One quantile() call for both quartiles; unname() drops the "25%"/"75%"
  # labels quantile() attaches.
  quartiles <- round(
    unname(quantile(x, probs = c(0.25, 0.75), na.rm = TRUE)),
    digits = 0
  )
  paste0(med, " [", quartiles[1], ",", quartiles[2], "]")
}
# Mean and standard error of `maxtrop_sameday` among rows where it is positive.
#
# Args:
#   dt:  table with a numeric `maxtrop_sameday` column.
#   dig: decimal places for both the mean and the standard error.
#
# Returns: character vector of length 2, c(mean, "(se)"). A standard error
#   below 0.001 is reported as "(<0.001)"; with fewer than two positive values
#   the standard error is undefined and shown as "(NA)".
#
# The standard error is sd / sqrt(n) over the positive values. The previous
# expression, mean * (1 - mean) / sqrt(n), is a binomial-style formula that
# does not apply to a continuous mean and turns negative once the mean
# exceeds 1. The SE is now always wrapped in parentheses, not only in the
# "<0.001" case.
get_mean_se_pos_tn <- function(dt, dig = 3) {
  tn <- dt[["maxtrop_sameday"]]
  pos_tn <- tn[!is.na(tn) & tn > 0]
  num_pos <- length(pos_tn)

  mean_pos <- round(mean(pos_tn), digits = dig)

  se_raw_pos <- sd(pos_tn) / sqrt(num_pos)
  se_chr <- if (isTRUE(se_raw_pos < 0.001)) {
    "<0.001"
  } else {
    as.character(round(se_raw_pos, digits = dig))
  }

  c(as.character(mean_pos), paste0("(", se_chr, ")"))
}

# Share of rows where a column is non-missing, with its standard error.
#
# Args:
#   varname: name (string) of the column in `dt`.
#   dt:      data frame / tibble / data.table holding the column.
#   dig:     decimal places for both the share and the standard error.
#
# Returns: character vector of length 2, c(share, "(se)"). A standard error
#   below 0.001 is reported as "(<0.001)". The SE is always parenthesised so
#   these rows match the ones produced by get_mean_se(); previously only the
#   "<0.001" case carried parentheses.
mean_se_notna <- function(varname, dt, dig = 3) {
  observed <- !is.na(dt[[varname]])
  share_chr <- as.character(round(mean(observed), digits = dig))

  se_raw <- sd(observed) / sqrt(nrow(dt))
  se_chr <- if (isTRUE(se_raw < 0.001)) {
    "<0.001"
  } else {
    as.character(round(se_raw, digits = dig))
  }

  c(share_chr, paste0("(", se_chr, ")"))
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here("lib", "filepaths.yml"))

# glue() fills any `{...}` placeholder in the cohort path from variables in
# scope. Presumably the path template references `overnight_lab` -- confirm
# against lib/filepaths.yml before touching this assignment.
overnight_lab <- ""
full_cohort <- mutate(
  readRDS(glue(paths$analysis$full_cohort)),
  maxtrop_sameday_pos = replace_na(maxtrop_sameday > 0, FALSE)
)

# Demographic features; the encounter key is renamed to match the other tables.
demos <- setnames(
  read_csv(paths$features$dem),
  "dem_ed_enc_id", "ed_enc_id"
)
death <- select(
  readRDS(paths$analysis$demographics),
  ed_enc_id, death_030_day, death_365_day
)
troponin <- readRDS(paths$analysis$troponin)
longterm_outcomes <- readRDS(paths$analysis$longterm_outcomes)

risk_vars <- c(
  "athero", "diabetes", "hypertension",
  "cholesterol", "athero_rf_custom_zch"
)
risk_factors <- select(
  readRDS(paths$analysis$risk_factors),
  ed_enc_id, all_of(risk_vars)
)

# Attach each lookup table in turn, then drop encounters flagged `exclude`.
cohort <- u$safe_left_join(full_cohort, demos)
cohort <- u$safe_left_join(cohort, risk_factors)
cohort <- u$safe_left_join(cohort, death)
cohort <- u$safe_left_join(cohort, longterm_outcomes)
cohort <- filter(cohort, !exclude)

# Encounter set for the troponin / same-day rows. It starts from `full_cohort`
# and applies its own exclusion flags plus an age cut-off rather than the
# `exclude` flag. The joins used for `cohort` remain disabled here:
#   u$safe_left_join(demos), u$safe_left_join(risk_factors),
#   u$safe_left_join(death), u$safe_left_join(mace_outcomes_p1)
cohort_tn_ami <- full_cohort %>%
  filter(
    !(excl_flag_c_int | excl_flag_chronic | excl_flag_death) &
      age_at_admit < 80
  ) %>%
  # Missing troponin counts as "not positive".
  mutate(maxtrop_sameday_pos = replace_na(maxtrop_sameday > 0, FALSE))

# Subset data ------------------------------------------------------------------
message("Subsetting data into tested/untested...")
race_vars <- grep("dem_race", names(cohort), value = TRUE)
income_vars <- grep("dem_agi", names(cohort), value = TRUE)

# Patient level: one row per `ptid`, flagged as tested if any of the patient's
# encounters has `test_010_day` set.
patients_all <- cohort %>%
  select(
    ptid, test_010_day, all_of(race_vars), dem_sex_female,
    all_of(income_vars), all_of(risk_vars)
  ) %>%
  group_by(ptid) %>%
  summarize(test_010_day = any(test_010_day)) %>%
  ungroup()
patients_t <- filter(patients_all, test_010_day)
patients_u <- filter(patients_all, !test_010_day)
assert(
  "Untested patients + tested patients = patients",
  nrow(patients_t) + nrow(patients_u) == nrow(patients_all)
)

# Encounter level: split the main cohort and the troponin cohort by the
# same flag.
enc_all <- cohort
enc_t <- filter(cohort, test_010_day)
enc_u <- filter(cohort, !test_010_day)

enc_tn_ami_all <- cohort_tn_ami
enc_tn_ami_t <- filter(cohort_tn_ami, test_010_day)
enc_tn_ami_u <- filter(cohort_tn_ami, !test_010_day)

assert(
  "Untested encs + tested encs = encs",
  nrow(enc_t) + nrow(enc_u) == nrow(enc_all)
)

# Build Table ------------------------------------------------------------------
message("Building table...")
# Counts are formatted one value at a time with a thousands separator.
fmt_count <- comma_format()
dt <- tibble(
  var_raw = c("n_patients", "n_encounters"),
  measure = c("count", "count"),
  All = c(fmt_count(nrow(patients_all)), fmt_count(nrow(enc_all))),
  Tested = c(fmt_count(nrow(patients_t)), fmt_count(nrow(enc_t))),
  Untested = c(fmt_count(nrow(patients_u)), fmt_count(nrow(enc_u)))
)

# Two rows (mean, se) per race indicator, computed at the encounter level.
race_dt <- tibble(
  var_raw = rep(race_vars, each = 2),
  measure = rep(c("mean", "se"), times = length(race_vars)),
  All = unlist(map(race_vars, get_mean_se, dt = enc_all)),
  Tested = unlist(map(race_vars, get_mean_se, dt = enc_t)),
  Untested = unlist(map(race_vars, get_mean_se, dt = enc_u))
)

# race_dt_check <- tibble(
#   var_raw = rep(race_vars, each = 2),
#   All = map(race_vars, get_mean_se, dt = enc_all) %>% unlist(),
#   Tested = map(race_vars, get_mean_se, dt = enc_t) %>% unlist(),
#   Untested = map(race_vars, get_mean_se, dt = enc_u) %>% unlist()
# )

# Age rows for one encounter set: mean (whole years), its se, then
# "median [q25,q75]".
age_column <- function(enc) {
  c(
    get_mean_se("age_at_admit", enc, dig = 0),
    get_med("age_at_admit", enc)
  )
}
age_dt <- tibble(
  var_raw = c("age_mean", "age_mean", "age_median_iqr"),
  measure = c("mean", "se", "median"),
  All = age_column(enc_all),
  Tested = age_column(enc_t),
  Untested = age_column(enc_u)
)

# Mean and se of the `dem_sex_female` indicator.
sex_dt <- tibble(
  var_raw = rep("female", 2),
  measure = c("mean", "se"),
  All = get_mean_se("dem_sex_female", enc_all),
  Tested = get_mean_se("dem_sex_female", enc_t),
  Untested = get_mean_se("dem_sex_female", enc_u)
)

# dist_dt <- tibble(
#   var_raw = c("miles_to_bwh"),
#   All = get_med("dem_dist_BWH_miles", enc_all),
#   Tested = get_med("dem_dist_BWH_miles", enc_t),
#   Untested = get_med("dem_dist_BWH_miles", enc_u)
# )
#
# demos_dt <- age_dt %>%
#   rbind(sex_dt) %>%
#   rbind(race_dt) %>%
#   mutate()

# Two rows (mean, se) per risk-factor column.
risk_dt <- tibble(
  var_raw = rep(risk_vars, each = 2),
  measure = rep(c("mean", "se"), times = length(risk_vars)),
  All = unlist(map(risk_vars, get_mean_se, dt = enc_all)),
  Tested = unlist(map(risk_vars, get_mean_se, dt = enc_t)),
  Untested = unlist(map(risk_vars, get_mean_se, dt = enc_u))
)

# Stack the pieces in display order: counts, age, sex, race, risk factors.
summary_dt <- Reduce(rbind, list(dt, age_dt, sex_dt, race_dt, risk_dt))

# Misc other rates -------------------------------------------------------------
# Overall rates of the three testing flags; each call returns c(mean, "(se)").
test_rate <- get_mean_se("test_010_day", enc_all)
stress_rate <- get_mean_se("stress_010_day", enc_all)
cath_rate <- get_mean_se("cath_010_day", enc_all)

# These rates are only reported for the full cohort; the split columns get
# a dash placeholder.
testing_rates_dt <- tibble(
  var_raw = rep(c("test_rate", "cath_rate", "stress_rate"), each = 2),
  measure = rep(c("mean", "se"), times = 3),
  All = c(test_rate, cath_rate, stress_rate),
  Tested = rep("-", 6),
  Untested = rep("-", 6)
)

# Physician Suspicion vars
# Share of encounters with a same-day troponin value, then the mean of the
# positive values, for one encounter set.
tn_column <- function(enc) {
  c(mean_se_notna("maxtrop_sameday", enc), get_mean_se_pos_tn(enc))
}
tn_dt <- tibble(
  var_raw = rep(c("tn_rate", "mean_pos_tn"), each = 2),
  measure = rep(c("mean", "se"), times = 2),
  All = tn_column(enc_tn_ami_all),
  Tested = tn_column(enc_tn_ami_t),
  Untested = tn_column(enc_tn_ami_u)
)

sameday_vars <- c("maxtrop_sameday_pos", "ami_day_of")
sameday_outcomes_dt <- tibble(
  var_raw = rep(sameday_vars, each = 2),
  measure = rep(c("mean", "se"), times = length(sameday_vars)),
  All = unlist(map(sameday_vars, get_mean_se, dt = enc_tn_ami_all)),
  Tested = unlist(map(sameday_vars, get_mean_se, dt = enc_tn_ami_t)),
  Untested = unlist(map(sameday_vars, get_mean_se, dt = enc_tn_ami_u))
)

# Outcomes
outcome_vars <- c(
  "death_030_day", "macetrop_030_pos", "macetrop_pos_or_death_030",
  "death_365_day", "stent_or_cabg_010_day",
  "stent_010_day", "cabg_010_day", "has_ecg"
)

outcomes_dt <- tibble(
  var_raw = rep(outcome_vars, each = 2),
  measure = rep(c("mean", "se"), times = length(outcome_vars)),
  All = unlist(map(outcome_vars, get_mean_se, dt = enc_all)),
  Tested = unlist(map(outcome_vars, get_mean_se, dt = enc_t)),
  Untested = unlist(map(outcome_vars, get_mean_se, dt = enc_u))
)

# Stack in display order: troponin, same-day outcomes, testing rates, outcomes.
suspicion_dt <- Reduce(
  rbind,
  list(tn_dt, sameday_outcomes_dt, testing_rates_dt, outcomes_dt)
)

# Labels -----------------------------------------------------------------------
message("Labeling variables...")
label_xwalk <- read_csv(here("lib", "summary_tbl_labels.csv"))

# Join the label crosswalk onto a stats table, blank the label on "se" rows so
# each label is printed once, and keep only the display columns.
# `Variable` is not created in this script, so it presumably comes from the
# crosswalk -- confirm against lib/summary_tbl_labels.csv.
add_labels <- function(stats_tbl) {
  stats_tbl %>%
    u$safe_left_join(label_xwalk) %>%
    mutate(Variable = ifelse(measure == "se", "", Variable)) %>%
    select(Variable, All, Tested, Untested)
}

labeled_summary <- add_labels(summary_dt)
labeled_suspicion <- add_labels(suspicion_dt)


# Format Table -----------------------------------------------------------------
# Convert both labelled tables to xtable objects. Nothing is written yet, so
# the status message names this step rather than repeating the save message.
message("Formatting table...")
xt_summary <- xtable(labeled_summary)
xt_suspicion <- xtable(labeled_suspicion)

# tab_1 <- dt %>%
#          gt(rowname_col = "Variable", groupname_col = "group") %>%
#          fmt_number(columns = vars(All, Tested, Untested), decimals = 2) %>%
#          fmt_number(columns = vars(All, Tested, Untested),
#                     rows = c("n_patients", "n_encounters"), decimals = 0)

# Text sanitizer handed to print.xtable(): rewrites only the characters this
# table uses that LaTeX would misread -- square brackets, percent signs and
# "<". All substitutions are literal (fixed = TRUE) and applied in this order.
# The result is returned invisibly.
sani_function <- function(x) {
  swaps <- list(
    c("[", "{[}"),
    c("]", "{]}"),
    c("%", "\\%"),
    c("<", "$<$")
  )
  for (swap in swaps) {
    x <- gsub(swap[1], swap[2], x, fixed = TRUE)
  }
  invisible(x)
}

# Save -------------------------------------------------------------------------
message("Saving table...")
# print.xtable() cannot create the output folder itself; make sure it exists
# so the write does not fail. This does nothing when the folder already exists.
dir.create(temp, recursive = TRUE, showWarnings = FALSE)

# Write each table as a LaTeX fragment, without row names, using the custom
# sanitizer.
print(
  xt_summary,
  type = "latex", file = file.path(temp, "desc_stats_summary.tex"),
  include.rownames = FALSE, sanitize.text.function = sani_function
)
print(
  xt_suspicion,
  type = "latex", file = file.path(temp, "desc_stats_suspicion.tex"),
  include.rownames = FALSE, sanitize.text.function = sani_function
)
# gtsave(tab_1, file.path(temp, "desc_stats.html"))

message("Done.")
